# Copyright 2009-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for the objectid module."""
from __future__ import annotations

import datetime
import pickle
import struct
import sys

# Insert "" (the current working directory) at the front of sys.path before
# the imports below — presumably so the repository's local ``test`` and
# ``bson`` packages win over any installed copies; confirm with the test runner.
sys.path[0:0] = [""]

from test import SkipTest, unittest
from test.utils_shared import oid_generated_on_process

from bson.errors import InvalidId
from bson.objectid import _MAX_COUNTER_VALUE, ObjectId
from bson.tz_util import FixedOffset, utc


def oid(x):
    """Create and return a brand-new :class:`ObjectId`.

    The positional argument *x* is accepted but never used; presumably it
    exists so this function can serve as a one-argument callback (e.g. with
    ``map``) — confirm against callers before changing the signature.
    """
    fresh_id = ObjectId()
    return fresh_id


class TestObjectId(unittest.TestCase):
    """Unit and spec tests for :class:`bson.objectid.ObjectId`."""

    def test_creation(self):
        """Unsupported types and malformed strings are rejected; valid inputs work."""
        # Unsupported argument types raise TypeError.
        self.assertRaises(TypeError, ObjectId, 4)
        self.assertRaises(TypeError, ObjectId, 175.0)
        self.assertRaises(TypeError, ObjectId, {"test": 4})
        self.assertRaises(TypeError, ObjectId, ["something"])
        # Malformed strings raise InvalidId.
        self.assertRaises(InvalidId, ObjectId, "")
        self.assertRaises(InvalidId, ObjectId, "12345678901")
        self.assertRaises(InvalidId, ObjectId, "1234567890123")
        # No-arg, 12-byte and copy construction all succeed.
        self.assertTrue(ObjectId())
        self.assertTrue(ObjectId(b"123456789012"))
        a = ObjectId()
        self.assertTrue(ObjectId(a))

    def test_unicode(self):
        """Copy construction preserves equality; a non-hex str is rejected."""
        a = ObjectId()
        self.assertEqual(a, ObjectId(a))
        self.assertRaises(InvalidId, ObjectId, "hello")

    def test_from_hex(self):
        """A 24-character hex string is accepted; a non-hex character is not."""
        ObjectId("123456789012123456789012")
        self.assertRaises(InvalidId, ObjectId, "123456789012123456789G12")

    def test_repr_str(self):
        """repr/str render the hex form and ``binary`` exposes the raw 12 bytes."""
        self.assertEqual(
            repr(ObjectId("1234567890abcdef12345678")), "ObjectId('1234567890abcdef12345678')"
        )
        self.assertEqual(str(ObjectId("1234567890abcdef12345678")), "1234567890abcdef12345678")
        self.assertEqual(str(ObjectId(b"123456789012")), "313233343536373839303132")
        self.assertEqual(
            ObjectId("1234567890abcdef12345678").binary, b"\x124Vx\x90\xab\xcd\xef\x124Vx"
        )
        self.assertEqual(
            str(ObjectId(b"\x124Vx\x90\xab\xcd\xef\x124Vx")), "1234567890abcdef12345678"
        )

    def test_equality(self):
        """Ids compare by value, fresh ids differ, and an id never equals raw bytes."""
        a = ObjectId()
        self.assertEqual(a, ObjectId(a))
        self.assertEqual(ObjectId(b"123456789012"), ObjectId(b"123456789012"))
        self.assertNotEqual(ObjectId(), ObjectId())
        self.assertNotEqual(ObjectId(b"123456789012"), b"123456789012")

        # Equal ids must also evaluate False under ``!=`` (exercises __ne__ directly).
        self.assertFalse(a != ObjectId(a))
        self.assertFalse(ObjectId(b"123456789012") != ObjectId(b"123456789012"))

    def test_binary_str_equivalence(self):
        """Rebuilding from ``binary`` or from ``str()`` yields an equal id."""
        a = ObjectId()
        self.assertEqual(a, ObjectId(a.binary))
        self.assertEqual(a, ObjectId(str(a)))

    def test_generation_time(self):
        """generation_time is UTC-aware and close to the wall-clock time."""
        d1 = datetime.datetime.now(tz=datetime.timezone.utc).replace(tzinfo=None)
        d2 = ObjectId().generation_time

        self.assertEqual(utc, d2.tzinfo)
        d2 = d2.replace(tzinfo=None)
        # The embedded timestamp has whole-second resolution, so d2 is normally
        # slightly *earlier* than d1 and ``d2 - d1`` is negative. Compare the
        # absolute difference so a timestamp far in the past also fails.
        self.assertLess(abs(d2 - d1), datetime.timedelta(seconds=2))

    def test_from_datetime(self):
        """from_datetime encodes the time (in UTC) and zeroes all other bytes."""
        # ObjectId timestamps carry whole seconds only, so drop microseconds.
        d = datetime.datetime.now(tz=datetime.timezone.utc).replace(tzinfo=None, microsecond=0)
        naive_id = ObjectId.from_datetime(d)
        self.assertEqual(d, naive_id.generation_time.replace(tzinfo=None))
        # Everything after the leading 8 hex digits (the timestamp) is zero.
        self.assertEqual("0" * 16, str(naive_id)[8:])

        # An aware datetime must be shifted to UTC before being encoded.
        aware = datetime.datetime(1993, 4, 4, 2, tzinfo=FixedOffset(555, "SomeZone"))
        offset = aware.utcoffset()
        assert offset is not None  # narrows the Optional for type checkers
        as_utc = (aware - offset).replace(tzinfo=utc)
        aware_id = ObjectId.from_datetime(aware)
        self.assertEqual(as_utc, aware_id.generation_time)

    def test_pickling(self):
        """An ObjectId survives a pickle round trip under every protocol."""
        orig = ObjectId()
        # Cover every protocol this interpreter supports, not just the legacy ones.
        for protocol in range(pickle.HIGHEST_PROTOCOL + 1):
            with self.subTest(protocol=protocol):
                pkl = pickle.dumps(orig, protocol=protocol)
                self.assertEqual(orig, pickle.loads(pkl))

    def test_pickle_backwards_compatability(self):
        """Pickles produced by old PyMongo releases still load to the same id."""
        # This string was generated by pickling an ObjectId with PyMongo 1.9.
        pickled_with_1_9 = (
            b"ccopy_reg\n_reconstructor\np0\n"
            b"(cbson.objectid\nObjectId\np1\nc__builtin__\n"
            b"object\np2\nNtp3\nRp4\n"
            b"(dp5\nS'_ObjectId__id'\np6\n"
            b"S'M\\x9afV\\x13v\\xc0\\x0b\\x88\\x00\\x00\\x00'\np7\nsb."
        )

        # A hard-coded pickle in the "new" (1.10) format, so that backward
        # compatibility with that format keeps being checked in the future.
        pickled_with_1_10 = (
            b"ccopy_reg\n_reconstructor\np0\n"
            b"(cbson.objectid\nObjectId\np1\nc__builtin__\n"
            b"object\np2\nNtp3\nRp4\n"
            b"S'M\\x9afV\\x13v\\xc0\\x0b\\x88\\x00\\x00\\x00'\np5\nb."
        )

        # These were pickled under Python 2.x, so they must be loaded with 'latin-1'.
        oid_1_9 = pickle.loads(pickled_with_1_9, encoding="latin-1")
        oid_1_10 = pickle.loads(pickled_with_1_10, encoding="latin-1")

        self.assertEqual(oid_1_9, ObjectId("4d9a66561376c00b88000000"))
        self.assertEqual(oid_1_9, oid_1_10)

    def test_random_bytes(self):
        """A fresh id passes the shared ``oid_generated_on_process`` check."""
        self.assertTrue(oid_generated_on_process(ObjectId()))

    def test_is_valid(self):
        """is_valid rejects bad types and lengths and accepts 12 bytes or 24 hex."""
        self.assertFalse(ObjectId.is_valid(None))
        self.assertFalse(ObjectId.is_valid(4))
        self.assertFalse(ObjectId.is_valid(175.0))
        self.assertFalse(ObjectId.is_valid({"test": 4}))
        self.assertFalse(ObjectId.is_valid(["something"]))
        self.assertFalse(ObjectId.is_valid(""))
        self.assertFalse(ObjectId.is_valid("12345678901"))
        self.assertFalse(ObjectId.is_valid("1234567890123"))

        self.assertTrue(ObjectId.is_valid(b"123456789012"))
        self.assertTrue(ObjectId.is_valid("123456789012123456789012"))

    def test_counter_overflow(self):
        """Spec test: the counter wraps from its maximum value back to 0."""
        ObjectId._inc = _MAX_COUNTER_VALUE
        ObjectId()
        self.assertEqual(ObjectId._inc, 0)

    def test_timestamp_values(self):
        """Spec test: the leading 4 bytes are read as an unsigned big-endian timestamp."""
        TEST_DATA = {
            0x00000000: (1970, 1, 1, 0, 0, 0),
            0x7FFFFFFF: (2038, 1, 19, 3, 14, 7),
            0x80000000: (2038, 1, 19, 3, 14, 8),
            0xFFFFFFFF: (2106, 2, 7, 6, 28, 15),
        }

        def generate_objectid_with_timestamp(timestamp):
            # Keep the trailing 8 bytes of a fresh id and overwrite its timestamp.
            template = ObjectId()
            _, trailing_bytes = struct.unpack(">IQ", template.binary)
            new_oid = struct.pack(">IQ", timestamp, trailing_bytes)
            return ObjectId(new_oid)

        for tstamp, exp_datetime_args in TEST_DATA.items():
            stamped = generate_objectid_with_timestamp(tstamp)
            # 32-bit platforms may overflow in datetime.fromtimestamp.
            if tstamp > 0x7FFFFFFF and sys.maxsize < 2**32:
                try:
                    stamped.generation_time
                except (OverflowError, ValueError):
                    continue
            self.assertEqual(
                stamped.generation_time, datetime.datetime(*exp_datetime_args, tzinfo=utc)
            )

    def test_random_regenerated_on_pid_change(self):
        """Changing the recorded pid makes ``_random()`` return a new value."""
        random_original = ObjectId._random()
        ObjectId._pid += 1
        random_new = ObjectId._random()
        self.assertNotEqual(random_original, random_new)


if __name__ == "__main__":
    # Entry point when this file is executed directly: hand off to unittest.main().
    unittest.main()
